Auto-install fused lm_head + cross_entropy forward (opt-in) #657
Adds an opt-in (UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites
the canonical lm_head + self.loss_function triplet on every transformers
`*ForCausalLM` / `*ForConditionalGeneration` whose forward matches the
shape used from transformers 4.56 onwards. Skipping logits.float() over
(seq_len x vocab_size) avoids the OOM that surfaced in
unslothai/unsloth#5441 and also avoids materialising the bf16 logits
tensor.
Layers:
unsloth_zoo/fused_losses/forward_adapter.py
Maps the HF self.loss_function(logits=..., labels=..., vocab_size=...,
**kwargs) calling convention onto unsloth_fused_ce_loss. Pops
num_items_in_batch -> n_items, threads ignore_index / label_smoothing /
logit_softcapping / logit_scale_multiply / logit_scale_divide, and
falls back to a stock CE if the caller passes a pre-shifted
shift_labels tensor (unsupported by the chunked kernel today).
unsloth_zoo/fused_losses/ast_rewriter.py
NodeTransformer that recognises the canonical triplet:
<NAME> = self.<HEAD>(<HIDDEN_EXPR>[...])
loss = None (optional)
if labels is not None:
<LOSS> = self.loss_function(<NAME>, labels, vocab_size=..., **kwargs)
and rewrites it to call unsloth_fused_lm_head_loss(<HIDDEN_EXPR>,
self.<HEAD>, labels, ...). Tolerates keyword vs positional vocab_size,
`.float()` / `[slice]` chains around the lm_head call, and detects
logits re-binding (e.g. Cohere's `logits = logits * self.logit_scale`)
as a refuse signal so we never produce a broken forward.
unsloth_zoo/fused_losses/forward_install.py
Two-tier installer: (1) hash-allowlist fast path via
register_canonical(hash, forward_fn); (2) AST triplet rewrite.
Driven by a meta-path import hook that intercepts
transformers.models.<X>.modeling_<X> imports and patches eligible
classes as their module loads. Soft floor at transformers >= 4.56.
audit() returns a JSON-safe dict of patched / unmatched / failed
classes for observability.
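The triplet recognition described for ast_rewriter.py can be sketched with the standard `ast` module. This is a minimal illustration, not the PR's code: `find_triplet` and `SAMPLE_FORWARD` are hypothetical names, and the sketch only handles a bare `self.<HEAD>(...)` assignment (no `.float()` / slice wrappers, no rebinding checks).

```python
import ast

# Hypothetical forward in the canonical shape the description refers to.
SAMPLE_FORWARD = '''
def forward(self, hidden_states, labels=None, **kwargs):
    logits = self.lm_head(hidden_states)
    loss = None
    if labels is not None:
        loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
    return loss, logits
'''

def _self_call_attr(node):
    """Return <attr> if node is a call of the form self.<attr>(...), else None."""
    if (isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "self"):
        return node.func.attr
    return None

def _is_labels_guard(test):
    """True for the expression `labels is not None`."""
    return (isinstance(test, ast.Compare)
            and isinstance(test.left, ast.Name) and test.left.id == "labels"
            and len(test.ops) == 1 and isinstance(test.ops[0], ast.IsNot)
            and isinstance(test.comparators[0], ast.Constant)
            and test.comparators[0].value is None)

def find_triplet(fn):
    """Return (logits_name, head_attr) if fn holds the canonical triplet, else None."""
    for i, stmt in enumerate(fn.body):
        if not (isinstance(stmt, ast.Assign) and len(stmt.targets) == 1
                and isinstance(stmt.targets[0], ast.Name)):
            continue
        head = _self_call_attr(stmt.value)
        if head is None or head == "loss_function":
            continue
        logits_name = stmt.targets[0].id
        # Look for a later `if labels is not None:` whose direct body calls
        # self.loss_function(<logits_name>, ...).
        for later in fn.body[i + 1:]:
            if not (isinstance(later, ast.If) and _is_labels_guard(later.test)):
                continue
            for inner in later.body:
                if (isinstance(inner, ast.Assign)
                        and _self_call_attr(inner.value) == "loss_function"):
                    args = inner.value.args
                    if args and isinstance(args[0], ast.Name) and args[0].id == logits_name:
                        return logits_name, head
    return None

fn_node = ast.parse(SAMPLE_FORWARD).body[0]
print(find_triplet(fn_node))
```

A real rewriter would go on to replace the matched statements; the sketch stops at detection because that is where the accept/decline decision is made.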
Kernel updates:
unsloth_zoo/fused_losses/cross_entropy_loss.py
compute_fused_ce_loss + UnslothFusedLoss.forward now thread
ignore_index (default -100) into the label-shift step and the inner
F.cross_entropy call. compute_fused_ce_loss also accepts
label_smoothing. Matches HF ForCausalLMLoss semantics so callers
that override either no longer silently regress.
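To make the ignore_index semantics concrete, here is a dependency-free reference of a causal-shifted cross entropy over one sequence. It is a sketch under stated assumptions, not the kernel: `shifted_cross_entropy` is a hypothetical helper, it works on plain Python lists, it omits label_smoothing, and it treats a supplied `n_items` as the divisor of the summed loss (the sum-then-divide scaling this thread attributes to HF's ForCausalLMLoss).

```python
import math

def shifted_cross_entropy(logits, labels, ignore_index=-100, n_items=None):
    """Mean CE where position t predicts labels[t + 1].

    logits: list of per-position score lists, shape [seq][vocab].
    labels: list of ints, shape [seq]; entries equal to ignore_index are skipped.
    n_items: optional divisor; defaults to the number of non-ignored targets.
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]          # causal shift: drop the first label
        if target == ignore_index:
            continue
        row = logits[t]
        m = max(row)                    # subtract the max for a stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
        count += 1
    denom = count if n_items is None else n_items
    return total / denom if denom else 0.0

uniform = [[0.0] * 4 for _ in range(4)]
print(shifted_cross_entropy(uniform, [0, 1, -100, 2]))
```

With uniform logits every counted position contributes log(vocab), so changing which labels are ignored changes the count but not the mean, which makes the helper easy to sanity-check.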
Tests (tests/test_fused_forward_install.py, 14 cases):
- AST rewriter accepts keyword form, positional vocab_size, `.float()`
wrapper. Declines non-canonical, declines on logits rebinding.
- install_for_class: noop when disabled, skips ineligible names,
patches canonical, idempotent, function-override fast path,
audit() snapshot.
- Numerical equivalence on a toy CUDA model: fused loss within
bf16 -> fp32 rounding noise of the reference.
- Kernel respects ignore_index and label_smoothing kwargs.
End-to-end equivalence on Llama-3.2-1B + alpaca-cleaned (seed 3407,
max_steps 10): identical step-1 loss + grad_norm, max |loss delta| =
0.005, max |grad_norm delta| = 0.025 across the run. Audit reported
19 classes patched, 0 failed when UNSLOTH_FUSED_FORWARD=1 (LlamaForCausalLM,
Qwen3ForCausalLM, MistralForCausalLM, Gemma2/3 / GemmaForCausalLM,
Mllama, DeepseekV3, Qwen3MoE / Qwen3Next, Bloom, FalconH1, etc.).
Off by default. Set UNSLOTH_FUSED_FORWARD=1 to opt in.
Code Review
This pull request introduces an opt-in auto-installer for fused lm_head and cross_entropy losses, utilizing an AST-level rewriter and import hooks to patch transformers models. The implementation includes updates to the fused cross-entropy kernel to support ignore_index and label_smoothing, along with a comprehensive test suite. Review feedback suggests refining the exception handling during installation to improve visibility and reconsidering the aggressive stripping of decorators in the AST rewriter to avoid potential side effects on model logic.
| try: | ||
| from .fused_losses.forward_install import install_modeling_import_hook as _install_fused_forward | ||
| _install_fused_forward() | ||
| del _install_fused_forward | ||
| except Exception: | ||
| pass |
The broad try...except Exception: pass block around the fused forward installation can make debugging difficult if the installer fails for unexpected reasons. It is recommended to at least print the exception to aid in troubleshooting, especially for visibility in 'studio' environments, as this is an opt-in feature that users might want to verify.
References
- Use print instead of logger.info for messages that must be visible in 'studio' when working with llama.cpp, as logger.info messages may be filtered out.
| # that may not have them visible. The decorators only add docstring | ||
| # sugar / tuple-return handling and are not needed for the runtime | ||
| # forward we install. | ||
| fn.decorator_list = [] |
Stripping all decorators from the rewritten forward function might lead to unexpected behavior if the model relies on functional decorators (e.g., for compilation hints, custom logic, or hooks). While many transformers decorators are docstring-related, a more selective approach or a clear justification for why this is safe for all supported models would be preferable. If the intent is to avoid issues with decorators not being present in the exec namespace, note that ns is already initialized with the original function's globals.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 680c9a3788
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| ) | ||
| {logits} = EMPTY_LOGITS | ||
| else: | ||
| {logits} = self.{head_attr}({hidden_src}) |
Preserve wrappers in the labels-none branch
When the matched original assignment is wrapped, e.g. lm_logits = self.lm_head(hidden_states).float() as covered by the new positional test fixture, the rewritten labels is None path reconstructs only self.<head>(hidden) and drops the wrapper. That means opt-in patched models return different generation/eval logits (dtype/device/view transformations such as .float(), .contiguous(), or post-call slicing are lost) even though the installer claims the no-labels path is untouched; store and reuse the full original RHS for the else branch or decline these matches.
| # that may not have them visible. The decorators only add docstring | ||
| # sugar / tuple-return handling and are not needed for the runtime | ||
| # forward we install. | ||
| fn.decorator_list = [] |
Keep tuple-return decorators on rewritten forwards
Stripping all decorators removes runtime behavior from Transformers forwards, not just docstring sugar. In current Transformers, can_return_tuple pops/uses return_dict and converts a ModelOutput to output.to_tuple() when return_dict=False; after this rewrite, patched classes with that decorator will ignore the standard return_dict=False API and return a dataclass instead of a tuple. Since the exec namespace is copied from the original globals, preserve/reapply runtime decorators such as can_return_tuple or only strip known documentation-only decorators.
Forwards routed through unsloth_compiled_cache see __globals__ for the cached module, which does not always re-import the HF output dataclass the original modeling file referenced (e.g. Gemma3ForCausalLM's return statement uses CausalLMOutputWithPast). Populate the exec namespace with everything from transformers.modeling_outputs as a fallback so the rewritten forward links cleanly. Caught during multi-model equivalence run (Gemma3-1B fused) which now matches the stock path bit-for-bit alongside Llama, Qwen3, Phi3, and Mistral.
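The namespace fallback can be sketched as a non-overwriting copy of a module's public names. This is an illustration only: `seed_namespace` is a hypothetical helper and `fake_outputs` is a stand-in for transformers.modeling_outputs so the example runs without transformers installed.

```python
import types

def seed_namespace(ns, module):
    """Copy public names from module into ns without overwriting existing entries."""
    for name in dir(module):
        if name.startswith("_"):
            continue
        ns.setdefault(name, getattr(module, name))
    return ns

# Stand-in module (assumption): plays the role of transformers.modeling_outputs.
fake_outputs = types.ModuleType("fake_outputs")
fake_outputs.CausalLMOutputWithPast = type("CausalLMOutputWithPast", (), {})
fake_outputs.BaseModelOutput = type("BaseModelOutput", (), {})
fake_outputs._private_helper = object()

# The exec namespace already holds one of the names from the original globals.
ns = {"CausalLMOutputWithPast": "from-original-globals"}
seed_namespace(ns, fake_outputs)
print(sorted(ns))
```

Using `setdefault` keeps whatever the original globals already provided and only fills in names that would otherwise be missing when the rewritten forward is linked.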
Multi-model equivalence run
Ran the same
Notes:
Also fixed a
Artifacts:
This PR appears to address open issue(s). The duplicate detector matched the following open issues with HIGH confidence:
If this PR fixes any of them, consider adding |
forward_adapter.py
- shift_labels fallback now uses reduction=sum and divides by n_items when
  num_items_in_batch is supplied, matching HF ForCausalLMLoss
  gradient-accumulation scaling.
- shift_labels=False (bool) now routes to the same stock-CE fallback instead
  of leaking through to the always-shifting fused kernel.
- Removed redundant inner import torch.

cross_entropy_loss.py
- Promote a non-tensor n_items divisor (HF trainers pass a Python int via
  gradient accumulation) to a scalar tensor before the existing DataParallel
  .numel()/.ravel() guard, which is preserved verbatim. Without the promotion
  an int n_items raises AttributeError inside the autograd forward.

ast_rewriter.py
- Capture the full lm_head RHS (e.g. .float()/.contiguous()/[slice]) and emit
  it in the else-branch so the inference path keeps its original dtype/shape
  semantics.
- Only strip docstring-only decorators (auto_docstring, add_start_docstrings*,
  add_end_docstrings, replace_return_docstrings); @can_return_tuple carries
  return_dict=False semantics and stays.
- Reject forwards with non-empty else, multi-statement labels branches, or
  aliased labels arguments (CSM-style depth-decoder loss survives intact
  rather than being silently dropped).
- Reject forwards where any statement between the lm_head assign and the
  labels-if mutates or reads logits (Gemma3 final_logit_softcapping used to be
  silently bypassed by the fused-loss path).
- Forward explicit loss_function keywords beyond vocab_size (Bloom passes
  num_items_in_batch=kwargs.get(...) without a **kwargs unpack).
- _find_loss_function_call / _find_loss_assign_target now inspect only the
  direct if-body, so a nested guard inside the labels branch is not silently
  dropped.

forward_install.py
- Drop *ForConditionalGeneration from auto-install eligibility (the fused
  kernel hardcodes a causal shift; aligned-label seq2seq losses would be
  off-by-one).
- Skip composite/non-linear heads via a _LINEAR_HEAD_ATTRS allowlist so
  BigBird-style self.cls(...) (BigBirdOnlyMLMHead) is not patched.
- install_for_class / install_for_module now also gate on the transformers
  version floor, matching install_modeling_import_hook.
- Inject transformers.utils.generic.can_return_tuple into the exec namespace
  so the preserved decorator resolves at runtime.
Compress narrative docstrings and inline rationale blocks across fused_losses/* and the __init__.py opt-in stanza. Load-bearing notes (@can_return_tuple semantics, Gemma3 softcap reasoning, BigBird composite-head guard, transformers >= 4.56 floor, ForCausalLM-only eligibility) are preserved; only WHAT-restating prose was removed.
Auto-review verdict: Approved

PR #657 adds an opt-in (UNSLOTH_FUSED_FORWARD=1) AST-level rewriter that intercepts transformers *ForCausalLM imports and routes loss through Unsloth's chunked fused lm_head+cross_entropy kernel, skipping the bf16 logits and fp32 cast for measurable VRAM savings during training. Review-hardened the rewriter against real model shapes (Bloom, CSM, Gemma3 softcap, BigBird composite head, T5Gemma2) and the adapter against integer num_items_in_batch and pre-shifted-label scaling, so it now correctly rewrites canonical forwards and cleanly refuses the rest.

Reason: All 12 real issues identified during review were fixed; tests pass; remaining items were design choices or hypothetical edge cases.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c33abf4d8e
| tgt = stmt.targets[0] | ||
| if not (isinstance(tgt, ast.Name) and tgt.id == logits_name): | ||
| continue | ||
| inner = _find_inner_self_call(stmt.value) |
Refuse transformed lm_head expressions before fusing
When a matching forward applies a logits transform in the same assignment, e.g. logits = self.lm_head(hidden_states) * self.logit_scale or an inline softcap, this still captures the inner self.lm_head call and rewrites the labels branch to call unsloth_fused_lm_head_loss on the unmodified linear output. The no-labels branch keeps the full original RHS, so only training loss silently drops the transform; this should either decline anything beyond known-safe wrappers such as .float()/.contiguous()/slicing or encode the transform in the fused call.
| from .fused_losses.forward_install import install_modeling_import_hook as _install_fused_forward | ||
| _install_fused_forward() | ||
| del _install_fused_forward | ||
| except Exception: |
Cover the eight semantic fixes that landed in commit db90fa1 so regressions are caught at test time rather than at training time:
- test_ast_rewriter_declines_when_intermediate_touches_logits
  Gemma final_logit_softcapping between lm_head and the labels-if must not be silently bypassed.
- test_ast_rewriter_declines_when_labels_aliased
  CSM-style `loss = self.loss_function(..., labels=backbone_labels)` on an `if labels is not None:` gate must refuse.
- test_ast_rewriter_declines_non_trivial_labels_branch
  MoE-style auxiliary loss inside the labels branch must refuse so aux_loss + router_aux_loss_coef stays intact.
- test_ast_rewriter_forwards_explicit_extra_kwargs
  Bloom-style `num_items_in_batch=kwargs.get(...)` without **kwargs must reach the kernel.
- test_install_skips_for_conditional_generation
  *ForConditionalGeneration uses aligned labels; auto-install must skip.
- test_install_skips_composite_head
  BigBird-style `self.cls(...)` composite head must hit the _LINEAR_HEAD_ATTRS allowlist and log as non-linear-head.
- test_fused_kernel_accepts_int_n_items
  HF Trainer grad-accum passes a Python int divisor; kernel must promote it to a scalar tensor before the DataParallel guard.
- test_adapter_falls_back_when_shift_labels_false
  `shift_labels=False` bool must route through stock CE; the fused kernel always re-shifts.

All 22 tests pass (14 original + 8 new).

Multi-model end-to-end equivalence rerun against the post-review tree (seed 3407, max_steps=10, alpaca-cleaned):

model            s1 eq   max|loss d|   max|grad d|   n_patched
Llama-3.2-1B     True    0.00450       0.01276       11
Qwen3-0.6B       True    0.00490       0.07686       11
Gemma-3-1B       True    0.00000       0.00000       11
Mistral-7B-v.3   True    0.00370       0.03093       11

Step 1 loss + grad_norm are bitwise identical for every model; n_patched dropped from 19 -> 11 because ConditionalGeneration + Gemma2/3 (logits touched by softcap) + BigBird (composite head) are now correctly skipped.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ab283b9503
| {hidden_src}, self.{head_attr}, labels, | ||
| vocab_size={vocab}{extra}{kwargs_unpack}, | ||
| ) | ||
| {logits} = EMPTY_LOGITS |
Honor UNSLOTH_RETURN_LOGITS in fused forward
When this opt-in installer patches a model, any training/eval call with labels now always returns EMPTY_LOGITS, even if the user set the repository's documented UNSLOTH_RETURN_LOGITS=1 escape hatch for metric computation or logit inspection. The existing compiler path conditionally materializes logits under that env var, so this new path makes the same configured run lose logits and can break Trainer prediction/metrics flows that require them; this assignment should mirror that conditional or otherwise skip the fused shortcut when logits are requested.
is_enabled() now returns True unless UNSLOTH_FUSED_FORWARD is explicitly set to "0". Updated docstrings and the __init__.py comment to reflect the new default.

The two-tier installer + LOSS_MAPPING backstop in #656 means the worst case for any class we touch is no-op (refused via _UNMATCHED or composite-head guard) -- never a worse forward than the stock path.

Test suite (23 cases, was 22 + new test_install_default_is_on): all green.

Refresh of the multi-model equivalence rerun with no env var set versus UNSLOTH_FUSED_FORWARD=0 (Llama-3.2-1B / Qwen3-0.6B / Gemma-3-1B / Mistral-7B-v0.3, seed 3407, max_steps=10, alpaca-cleaned):

model            off enabled   default enabled   s1 eq   max|loss d|   max|grad d|
Llama-3.2-1B     False         True              True    0.00410       0.02336
Qwen3-0.6B       False         True              True    0.00680       0.02561
Gemma-3-1B       False         True              True    0.00000       0.00000
Mistral-7B-v.3   False         True              True    0.00530       0.05310

Step 1 loss + grad_norm bitwise identical for every model; deltas across the run stay within bf16 -> fp32 chunked-CE rounding noise. Audit reports 11 classes patched at default and 0 patched when explicitly disabled.
| install_for_module(mod) | ||
| except Exception: | ||
| continue | ||
| _INSTALL_DONE = True |
| try: | ||
| replacement.__qualname__ = forward.__qualname__ | ||
| replacement.__module__ = forward.__module__ | ||
| except Exception: |
| try: | ||
| from transformers.utils.generic import can_return_tuple | ||
| ns.setdefault("can_return_tuple", can_return_tuple) | ||
| except Exception: |
| if _name.startswith("_"): | ||
| continue | ||
| ns.setdefault(_name, getattr(_mo, _name)) | ||
| except Exception: |
| new_forward.__qualname__ = forward.__qualname__ | ||
| new_forward.__module__ = forward.__module__ | ||
| new_forward.__doc__ = forward.__doc__ | ||
| except Exception: |
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: db4e5ea3c5
| def is_enabled() -> bool: | ||
| # On by default; opt out via UNSLOTH_FUSED_FORWARD=0. | ||
| return os.environ.get("UNSLOTH_FUSED_FORWARD", "1") != "0" |
Keep fused-forward patching opt-in by default
When UNSLOTH_FUSED_FORWARD is unset, this returns True, and unsloth_zoo.__init__ immediately calls install_modeling_import_hook(), so importing the package monkey-patches every eligible Transformers modeling module even though this change is described as opt-in. That makes existing training/eval runs pick up the new AST-rewritten forward without consent, including any edge cases that the installer does not yet handle; default this to disabled unless the env var is explicitly set to 1.
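The opt-in default this comment asks for amounts to flipping the fallback value and the comparison. A minimal sketch follows; the `environ` parameter is a hypothetical addition for testability and is not part of the PR's `is_enabled()` signature.

```python
import os

def is_enabled(environ=None):
    """Off by default; only an explicit UNSLOTH_FUSED_FORWARD=1 opts in."""
    env = os.environ if environ is None else environ
    return env.get("UNSLOTH_FUSED_FORWARD", "0") == "1"

print(is_enabled({}))                                # unset: disabled
print(is_enabled({"UNSLOTH_FUSED_FORWARD": "1"}))    # explicit opt-in
```

Comparing against "1" rather than "not 0" also means unexpected values such as "true" or an empty string leave the patching disabled.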
trl 1.x padding_free passes shift_labels=<tensor> through the loss
function. The adapter previously fell back to a materialised-logits
F.cross_entropy in that case, which kept the OOM problem the chunked
kernel was supposed to fix.
Plumb shift_labels through unsloth_fused_ce_loss instead. The outer
UnslothFusedLoss.forward already handles label shifting; when the
caller pre-shifted we just flatten and skip the inner re-shift.
Files:
- cross_entropy_loss.py: unsloth_fused_ce_loss gains shift_labels arg
(default True). Outer adds an else branch that flattens pre-shifted
labels so chunking aligns with hidden_states. The four inner
accumulate_chunk call sites pass False unconditionally now since
the outer always normalises labels.
- forward_adapter.py: drop the F.cross_entropy fallback. Pick (target,
do_shift) based on the shift_labels kwarg and call the fused kernel
with shift_labels=do_shift.
- test_fused_forward_install.py: rename the stale fallback test and
add five fp32-strict numerical checks (atol/rtol=1e-5):
* auto-shift matches F.cross_entropy
* pre-shifted tensor matches F.cross_entropy
* shift_labels=False matches F.cross_entropy
* num_items_in_batch divides correctly
* int and 0-d tensor n_items produce equal loss
Empirical end-to-end checks (10 step Llama-3.2-1B LoRA, max_steps=10):
trl 1.4.0 padding_free=True, fused vs off:
step 1 loss: 1.45730 == 1.45730 (exact)
max delta over 10 steps: 0.003 (bf16 noise)
num_items_in_batch wiring (batch=2, grad_accum=4):
HF passes a scalar tensor, consistent across the 4 micro-batches
in each window. n_items equals sum(non_ignore_labels) - rows in
every window (the per-row causal-shift drop), matching the
post-shift count HF uses for the mean reduction.
27/27 unit tests pass.
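The n_items observation above can be checked by hand with a small counting helper. This is a sketch on plain Python lists; `post_shift_item_count` is a hypothetical name, and the "sum(non_ignore_labels) - rows" identity assumes the first label of every row is not ignored.

```python
def post_shift_item_count(label_rows, ignore_index=-100):
    """Count targets that survive the causal shift (row[1:]) and are not ignored."""
    return sum(1 for row in label_rows for tok in row[1:] if tok != ignore_index)

rows = [
    [5, 6, -100, 7],      # 3 non-ignored labels, 2 remain after the shift
    [8, -100, -100, 9],   # 2 non-ignored labels, 1 remains after the shift
]
non_ignored = sum(1 for row in rows for tok in row if tok != -100)
print(post_shift_item_count(rows), non_ignored - len(rows))
```

If a row starts with an ignored label, the shift drops nothing countable from that row, so the two quantities differ by one for each such row.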
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ec66bc6cc6
| {loss} = unsloth_fused_lm_head_loss( | ||
| {hidden_src}, self.{head_attr}, labels, | ||
| vocab_size={vocab}{extra}{kwargs_unpack}, |
Preserve custom loss_function dispatch before fusing
When a patched *ForCausalLM instance overrides self.loss_function (for example by setting a custom _loss_function/loss type for distillation or an experimental objective), the original forward would call that callable here, but the rewrite now always routes the labels path through unsloth_fused_lm_head_loss and computes standard fused CE. That silently changes the training objective for those models; either guard the fused branch on the default causal-LM loss at runtime or decline classes/instances whose loss_function is not the supported loss.
The fused-forward installer (forward_install.py) rewrites *ForCausalLM.forward at import time. Two upstream-pattern tests used inspect.getsource(cls.forward) and got the rewritten body, which no longer contains the canonical upstream lines compiler.py pins. Switch both probes to read the modeling module's on-disk source via __file__ instead. That is the source compiler.py's rewriter actually operates on, and it stays pristine regardless of any runtime patches. Tests affected: - test_compiler_cross_entropy_lm_head_pattern_present - test_compiler_cross_entropy_find_2_loss_function_signature
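Reading a module's on-disk source instead of calling inspect.getsource on a possibly patched function can be sketched as follows. `module_source_from_disk` is a hypothetical helper, and the stdlib `json.decoder` module stands in for a transformers modeling module so the example runs anywhere; it assumes the module ships as a plain .py file.

```python
import importlib
from pathlib import Path

def module_source_from_disk(module_name):
    """Return the text of the file backing module_name, ignoring runtime patches."""
    module = importlib.import_module(module_name)
    return Path(module.__file__).read_text(encoding="utf-8")

# Stand-in for e.g. a modeling_<X> module (assumption for illustration).
source = module_source_from_disk("json.decoder")
print("class JSONDecoder" in source)
```

Replacing an attribute on a class at runtime does not touch the file, so a pattern probe against this text sees the upstream lines even after a forward has been rewritten.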
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1d8bc08e1f
| return unsloth_fused_ce_loss( | ||
| trainer = None, | ||
| hidden_states = hidden_states, | ||
| lm_head_weight = lm_head.weight, |
Guard against adapter-wrapped lm_head modules
When lm_head is a PEFT/LoRA, quantized, or otherwise wrapped Linear (for example when users include lm_head in LoRA target_modules), its forward applies adapter deltas/dequantization/hooks on top of the base parameters. This fused path bypasses lm_head(hidden_states) and feeds only lm_head.weight/.bias to F.linear, so the labels branch computes loss without the wrapper behavior and those adapter parameters receive no gradient, while the no-label branch still uses the real head. Please either restrict this fast path to exact plain torch.nn.Linear heads at runtime or fold supported wrapper weights into the fused computation.
* Honor UNSLOTH_RETURN_HIDDEN_STATES / UNSLOTH_RETURN_LOGITS in fused forward

  The AST-rewritten forward installed by PR #657 only had two branches: labels-not-None (fused CE, EMPTY_LOGITS) and else (real logits, no loss). It silently ignored both env vars that the compiler-rewritten forward in unsloth_zoo/compiler.py honors. For GRPO the compiled forward overrides the AST one so this never mattered in practice, but it left the AST forward behaviourally different from the compiled one and not safe to rely on standalone.

  Expand the rewrite template to the same three-branch shape as the compiled forward:
  1. UNSLOTH_RETURN_HIDDEN_STATES=1 -> hidden_states in the logits slot, no lm_head matmul, no loss. GRPO's hidden-states fast path.
  2. labels is not None -> fused CE for loss; logits = EMPTY_LOGITS unless UNSLOTH_RETURN_LOGITS=1, in which case the original lm_head expression runs so callers can train + collect logits in one forward.
  3. otherwise -> original RHS verbatim, loss = None.

  forward_install.py: seed the rewritten forward's globals with os so the env-var reads work on classes whose original forward did not import os.

  Tests: ordering assertion on the rewriter output plus four CUDA-gated behaviour tests covering each branch and the priority of return-hidden over return-logits when both are set.

* Drop UNSLOTH_RETURN_HIDDEN_STATES handling from AST forward

  The hidden-states fast path is owned by the compiler-rewritten forward in unsloth_zoo/compiler.py, which already overrides the AST forward for every *ForCausalLM class that GRPO actually runs on. Honoring the env var in the AST forward as well was defence-in-depth that nobody hits. Keep the UNSLOTH_RETURN_LOGITS opt-in (closes a real gap: lets callers collect real logits + train via fused CE in one forward).

  Template now goes back to two top-level branches with a nested if for the logits opt-in:

      if labels is not None:
          <fused CE>
          if UNSLOTH_RETURN_LOGITS == '1':
              logits = <original RHS>
          else:
              logits = EMPTY_LOGITS
      else:
          logits = <original RHS>
          loss = None

  Tests trimmed to match (29 passed). The ns.setdefault('os', os) seed in forward_install.py stays -- the UNSLOTH_RETURN_LOGITS read still needs os available in the rewritten forward's globals.

* Avoid double lm_head matmul on UNSLOTH_RETURN_LOGITS=1 path

  Previous shape called both unsloth_fused_lm_head_loss (which chunks the lm_head matmul internally to compute CE) and self.<head>(<hidden>) (the full matmul) when the opt-in env var was set. Two matmuls for one materialised tensor.

  New shape splits the labels branch into two paths and picks the right loss path for each:

      if labels is not None:
          if UNSLOTH_RETURN_LOGITS == '1':
              logits = <original RHS>                     # one matmul
              loss = self.loss_function(logits, labels,   # same logits
                                        vocab_size=V, ...)
          else:
              loss = unsloth_fused_lm_head_loss(...)      # chunked, fused
              logits = EMPTY_LOGITS
      else:
          logits = <original RHS>
          loss = None

  The opt-in path now routes through the model's own self.loss_function on the already-materialised logits. Matches HF's standard CausalLM loss shape and the conditional in unsloth_zoo/compiler.py:2074.

  Tests assert single-matmul + single-self.loss_function on the opt-in path; numerical equivalence holds bit-identically on the toy in this sim (5.003798 vs 5.003798).
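The final branch shape can be exercised without a model by passing the head and the two loss paths in as plain callables. Everything here is a stand-in: `forward_tail`, the `environ` parameter, the stubs, and the `EMPTY_LOGITS` sentinel are hypothetical and only mirror the template's control flow.

```python
import os

EMPTY_LOGITS = object()  # stand-in sentinel for the real EMPTY_LOGITS

def forward_tail(hidden, labels, head, loss_function, fused_loss, environ=None):
    """Mirror the final template: at most one head call per forward."""
    env = os.environ if environ is None else environ
    if labels is not None:
        if env.get("UNSLOTH_RETURN_LOGITS", "0") == "1":
            logits = head(hidden)                      # one matmul
            loss = loss_function(logits, labels)       # reuses the same logits
        else:
            loss = fused_loss(hidden, head, labels)    # chunked, no full logits
            logits = EMPTY_LOGITS
    else:
        logits = head(hidden)
        loss = None
    return loss, logits

# Counting stubs so the number of calls on each path is observable.
calls = {"head": 0, "stock": 0, "fused": 0}

def head(hidden):
    calls["head"] += 1
    return ("logits", hidden)

def stock_loss(logits, labels):
    calls["stock"] += 1
    return 1.0

def fused_loss(hidden, head_module, labels):
    calls["fused"] += 1
    return 1.0

forward_tail("h", [1, 2], head, stock_loss, fused_loss, environ={"UNSLOTH_RETURN_LOGITS": "1"})
print(calls)
```

On the opt-in path the counters show one head call and one stock-loss call with the fused path untouched, which is the single-matmul property the commit's tests assert.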
trl 1.x padding_free passes shift_labels=<tensor> through the loss
function. The adapter previously fell back to a materialised-logits
F.cross_entropy in that case, which kept the OOM problem the chunked
kernel was supposed to fix.
Plumb shift_labels through unsloth_fused_ce_loss instead. The outer
UnslothFusedLoss.forward already handles label shifting; when the
caller pre-shifted we just flatten and skip the inner re-shift.
Files:
- cross_entropy_loss.py: unsloth_fused_ce_loss gains shift_labels arg
(default True). Outer adds an else branch that flattens pre-shifted
labels so chunking aligns with hidden_states. The four inner
accumulate_chunk call sites pass False unconditionally now since
the outer always normalises labels.
- forward_adapter.py: drop the F.cross_entropy fallback. Pick (target,
do_shift) based on the shift_labels kwarg and call the fused kernel
with shift_labels=do_shift.
- test_fused_forward_install.py: rename the stale fallback test and
add five fp32-strict numerical checks (atol/rtol=1e-5):
* auto-shift matches F.cross_entropy
* pre-shifted tensor matches F.cross_entropy
* shift_labels=False matches F.cross_entropy
* num_items_in_batch divides correctly
* int and 0-d tensor n_items produce equal loss
Empirical end-to-end checks (10 step Llama-3.2-1B LoRA, max_steps=10):
trl 1.4.0 padding_free=True, fused vs off:
step 1 loss: 1.45730 == 1.45730 (exact)
max delta over 10 steps: 0.003 (bf16 noise)
num_items_in_batch wiring (batch=2, grad_accum=4):
HF passes a scalar tensor, consistent across the 4 micro-batches
in each window. n_items equals sum(non_ignore_labels) - rows in
every window (the per-row causal-shift drop), matching the
post-shift count HF uses for the mean reduction.
27/27 unit tests pass.
* Honor UNSLOTH_RETURN_HIDDEN_STATES / UNSLOTH_RETURN_LOGITS in fused forward

The AST-rewritten forward installed by PR unslothai#657 only had two branches: labels-not-None (fused CE, EMPTY_LOGITS) and else (real logits, no loss). It silently ignored both env vars that the compiler-rewritten forward in unsloth_zoo/compiler.py honors. For GRPO the compiled forward overrides the AST one so this never mattered in practice, but it left the AST forward behaviourally different from the compiled one and not safe to rely on standalone.

Expand the rewrite template to the same three-branch shape as the compiled forward:

1. UNSLOTH_RETURN_HIDDEN_STATES=1 -> hidden_states in the logits slot, no lm_head matmul, no loss. GRPO's hidden-states fast path.
2. labels is not None -> fused CE for loss; logits = EMPTY_LOGITS unless UNSLOTH_RETURN_LOGITS=1, in which case the original lm_head expression runs so callers can train + collect logits in one forward.
3. otherwise -> original RHS verbatim, loss = None.

forward_install.py: seed the rewritten forward's globals with os so the env-var reads work on classes whose original forward did not import os.

Tests: ordering assertion on the rewriter output plus four CUDA-gated behaviour tests covering each branch and the priority of return-hidden over return-logits when both are set.

* Drop UNSLOTH_RETURN_HIDDEN_STATES handling from AST forward

The hidden-states fast path is owned by the compiler-rewritten forward in unsloth_zoo/compiler.py, which already overrides the AST forward for every *ForCausalLM class that GRPO actually runs on. Honoring the env var in the AST forward as well was defence-in-depth that nobody hits. Keep the UNSLOTH_RETURN_LOGITS opt-in (closes a real gap: lets callers collect real logits + train via fused CE in one forward).

Template now goes back to two top-level branches with a nested if for the logits opt-in:

    if labels is not None:
        <fused CE>
        if UNSLOTH_RETURN_LOGITS == '1':
            logits = <original RHS>
        else:
            logits = EMPTY_LOGITS
    else:
        logits = <original RHS>
        loss = None

Tests trimmed to match (29 passed). The ns.setdefault('os', os) seed in forward_install.py stays -- the UNSLOTH_RETURN_LOGITS read still needs os available in the rewritten forward's globals.

* Avoid double lm_head matmul on UNSLOTH_RETURN_LOGITS=1 path

Previous shape called both unsloth_fused_lm_head_loss (which chunks the lm_head matmul internally to compute CE) and self.<head>(<hidden>) (the full matmul) when the opt-in env var was set. Two matmuls for one materialised tensor.

New shape splits the labels branch into two paths and picks the right loss path for each:

    if labels is not None:
        if UNSLOTH_RETURN_LOGITS == '1':
            logits = <original RHS>                      # one matmul
            loss = self.loss_function(logits, labels,    # same logits
                                      vocab_size=V, ...)
        else:
            loss = unsloth_fused_lm_head_loss(...)       # chunked, fused
            logits = EMPTY_LOGITS
    else:
        logits = <original RHS>
        loss = None

The opt-in path now routes through the model's own self.loss_function on the already-materialised logits. Matches HF's standard CausalLM loss shape and the conditional in unsloth_zoo/compiler.py:2074.

Tests assert single-matmul + single-self.loss_function on the opt-in path; numerical equivalence holds bit-identically on the toy in this sim (5.003798 vs 5.003798).
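The final branch shape from the last commit can be exercised with a toy stand-in. This is only a sketch: `forward_tail`, the call counters, and the toy `head` / `loss_function` / `fused_loss` callables are hypothetical, the environment is passed in as a dict instead of being read from os.environ, and nothing here is the rewriter's actual output. It shows the single-matmul property on the opt-in path and the absence of a full head call on the fused path.

```python
EMPTY_LOGITS = ()  # hypothetical stand-in for the placeholder returned on the fused path


def forward_tail(hidden, labels, head, loss_function, fused_loss, env):
    """Toy version of the two-level branch; every argument is a stand-in."""
    if labels is not None:
        if env.get("UNSLOTH_RETURN_LOGITS", "0") == "1":
            logits = head(hidden)                    # the only full head call
            loss = loss_function(logits, labels)     # reuses the same logits
        else:
            loss = fused_loss(hidden, head, labels)  # fused path, no full logits kept
            logits = EMPTY_LOGITS
    else:
        logits = head(hidden)
        loss = None
    return loss, logits


calls = {"head": 0, "loss_function": 0, "fused": 0}


def head(hidden):
    calls["head"] += 1
    return [h * 2.0 for h in hidden]  # toy "lm_head": scale by 2


def loss_function(logits, labels):
    calls["loss_function"] += 1
    return sum(abs(x - y) for x, y in zip(logits, labels))


def fused_loss(hidden, head_fn, labels):
    calls["fused"] += 1
    # Toy fused path: same value as loss_function(head(hidden), labels),
    # computed without invoking the full head call.
    return sum(abs(h * 2.0 - y) for h, y in zip(hidden, labels))
```

Calling `forward_tail` with `{"UNSLOTH_RETURN_LOGITS": "1"}` increments the `head` and `loss_function` counters once each; calling it with an empty env increments only `fused` and returns `EMPTY_LOGITS`.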
Summary

Opt-in (UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites the canonical lm_head + self.loss_function triplet on every transformers *ForCausalLM / *ForConditionalGeneration whose forward matches the shape used from transformers 4.56 onwards. Skipping logits.float() over (seq_len x vocab_size) avoids the OOM that surfaced in unslothai/unsloth#5441 and shaves the bf16 logits tensor as well.

Layers

unsloth_zoo/fused_losses/forward_adapter.py
Maps the HF self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs) calling convention onto unsloth_fused_ce_loss. Pops num_items_in_batch -> n_items, threads ignore_index / label_smoothing / logit_softcapping / logit_scale_multiply / logit_scale_divide, and falls back to a stock CE if the caller passes a pre-shifted shift_labels tensor (unsupported by the chunked kernel today).

unsloth_zoo/fused_losses/ast_rewriter.py
NodeTransformer that recognises the canonical triplet:

    <NAME> = self.<HEAD>(<HIDDEN_EXPR>[...])
    loss = None                                (optional)
    if labels is not None:
        <LOSS> = self.loss_function(<NAME>, labels, vocab_size=..., **kwargs)

and rewrites it to call unsloth_fused_lm_head_loss(<HIDDEN_EXPR>, self.<HEAD>, labels, ...). Tolerates keyword vs positional vocab_size, .float() / [slice] chains around the lm_head call, and detects logits re-binding (e.g. Cohere's logits = logits * self.logit_scale) as a refuse signal so we never produce a broken forward.

unsloth_zoo/fused_losses/forward_install.py
Two-tier installer: (1) hash-allowlist fast path via register_canonical(hash, forward_fn) (for future hand-written canonical forwards); (2) AST triplet rewrite. Driven by a meta-path import hook that intercepts transformers.models.<X>.modeling_<X> imports and patches eligible classes as their module loads. Soft floor at transformers >= 4.56. audit() returns a JSON-safe dict of patched / unmatched / failed classes for observability.

Kernel updates

unsloth_zoo/fused_losses/cross_entropy_loss.py
compute_fused_ce_loss + UnslothFusedLoss.forward now thread ignore_index (default -100) into the label-shift step and the inner F.cross_entropy call. compute_fused_ce_loss also accepts label_smoothing. Matches HF ForCausalLMLoss semantics so callers that override either no longer silently regress. (logit_softcapping, logit_scale_multiply, logit_scale_divide were already supported.)

Test plan

tests/test_fused_forward_install.py:
* vocab_size, .float() wrapper. Declines non-canonical, declines on logits rebinding.
* install_for_class: noop when disabled, skips ineligible names, patches canonical, idempotent, function-override fast path, audit() snapshot.
* ignore_index and label_smoothing kwargs.

unsloth/Llama-3.2-1B-Instruct + yahma/alpaca-cleaned, seed 3407, max_steps=10:
Step 1 loss and grad norm are bitwise identical. Across the run: max |loss delta| = 0.005, max |grad_norm delta| = 0.025 - both within bf16 -> fp32 chunked-CE rounding noise.

audit() after import with the flag on (Llama / Qwen3 / Mistral / Gemma3 / DeepseekV3 / Qwen3MoE / Bloom / FalconH1 / Mllama / Csm / Lfm2Vl / Qwen3VLMoe and 7 more): 19 classes patched, 0 failed, 6 unmatched (Cohere, Gemma3 VLM heads, GraniteMoeHybrid, CsmDepthDecoder - all expected outliers; the LOSS_MAPPING patch in #656, "Patch every LOSS_MAPPING key aliased to ForCausalLMLoss", backstops them).

Activation

Off by default. Set UNSLOTH_FUSED_FORWARD=1 to opt in. When on, fused install runs at import unsloth_zoo; new transformers modeling modules imported afterwards are patched via a meta-path hook. from unsloth_zoo.fused_losses import audit; audit() dumps the patched / unmatched / failed registry for debugging.

Related: unslothai/unsloth#5441, #656.
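For readers unfamiliar with the meta-path mechanism mentioned under Activation, the following is a minimal stdlib-only sketch of the general technique: a finder that wraps the real loader so a post-import step runs as soon as a matching module finishes executing. It is not the code in forward_install.py; `PatchingFinder`, `PatchingLoader`, `patch_module`, `PATCHED`, and the toy module name are all hypothetical, and the "patch" here merely records class names.

```python
import importlib
import importlib.abc
import os
import sys
import tempfile

PATCHED = []  # hypothetical registry of class names seen by the hook


def patch_module(module):
    """Hypothetical post-import step: record every *ForCausalLM class."""
    for name, obj in vars(module).items():
        if isinstance(obj, type) and name.endswith("ForCausalLM"):
            PATCHED.append(name)


class PatchingLoader(importlib.abc.Loader):
    """Delegates to the real loader, then runs the post-import step."""

    def __init__(self, inner):
        self.inner = inner

    def create_module(self, spec):
        return self.inner.create_module(spec)

    def exec_module(self, module):
        self.inner.exec_module(module)
        patch_module(module)


class PatchingFinder(importlib.abc.MetaPathFinder):
    """Intercepts imports whose name starts with `prefix`."""

    def __init__(self, prefix):
        self.prefix = prefix

    def find_spec(self, fullname, path, target=None):
        if not fullname.startswith(self.prefix):
            return None
        # Ask the remaining finders for the real spec, then wrap its loader.
        for finder in sys.meta_path:
            if finder is self or not hasattr(finder, "find_spec"):
                continue
            spec = finder.find_spec(fullname, path, target)
            if spec is not None and spec.loader is not None:
                spec.loader = PatchingLoader(spec.loader)
                return spec
        return None


def demo():
    """Import a throwaway module through the hook and return what was recorded."""
    PATCHED.clear()
    finder = PatchingFinder("modeling_toy_sketch")
    with tempfile.TemporaryDirectory() as tmp:
        with open(os.path.join(tmp, "modeling_toy_sketch.py"), "w") as fh:
            fh.write("class ToyForCausalLM:\n    pass\n\nclass ToyConfig:\n    pass\n")
        sys.path.insert(0, tmp)
        sys.meta_path.insert(0, finder)
        importlib.invalidate_caches()
        try:
            importlib.import_module("modeling_toy_sketch")
        finally:
            sys.meta_path.remove(finder)
            sys.path.remove(tmp)
            sys.modules.pop("modeling_toy_sketch", None)
    return list(PATCHED)
```

Wrapping the loader, rather than patching after an explicit import, is what lets modules imported later be handled without the caller doing anything; the real installer additionally decides eligibility and rewrites the forward, which this sketch does not attempt.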